{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## crnn ocr 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from torch.autograd import Variable\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import numpy as np\n",
    "from warpctc_pytorch import CTCLoss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建数据软连接"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ln -s /home/lywen/data/ocr ../data/ocr/1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../../')\n",
    "from train.ocr.dataset import PathDataset,randomSequentialSampler,alignCollate\n",
    "from glob import glob\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "roots = glob('./train/data/ocr/*/*.jpg')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练字符集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "alphabetChinese = '1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainP,testP = train_test_split(roots,test_size=0.1)##此处未考虑字符平衡划分\n",
    "traindataset = PathDataset(trainP,alphabetChinese)\n",
    "testdataset = PathDataset(testP,alphabetChinese)\n",
    "\n",
    "batchSize = 32\n",
    "workers = 1\n",
    "imgH = 32\n",
    "imgW = 280\n",
    "keep_ratio = True\n",
    "cuda = True\n",
    "ngpu = 1\n",
    "nh =256\n",
    "sampler = randomSequentialSampler(traindataset, batchSize)\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    traindataset, batch_size=batchSize,\n",
    "    shuffle=False, sampler=None,\n",
    "    num_workers=int(workers),\n",
    "    collate_fn=alignCollate(imgH=imgH, imgW=imgW, keep_ratio=keep_ratio))\n",
    "\n",
    "train_iter = iter(train_loader)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载预训练模型权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weights_init(m):\n",
    "    classname = m.__class__.__name__\n",
    "    if classname.find('Conv') != -1:\n",
    "        m.weight.data.normal_(0.0, 0.02)\n",
    "    elif classname.find('BatchNorm') != -1:\n",
    "        m.weight.data.normal_(1.0, 0.02)\n",
    "        m.bias.data.fill_(0)\n",
    "        \n",
    "from crnn.models.crnn import CRNN\n",
    "from config import ocrModel,LSTMFLAG,GPU\n",
    "model = CRNN(32, 1, len(alphabetChinese)+1, 256, 1,lstmFlag=LSTMFLAG)\n",
    "model.apply(weights_init)\n",
    "preWeightDict = torch.load(ocrModel,map_location=lambda storage, loc: storage)##加入项目训练的权重\n",
    "\n",
    "modelWeightDict = model.state_dict()\n",
    "\n",
    "for k, v in preWeightDict.items():\n",
    "            name = k.replace('module.','') # remove `module.`\n",
    "            if  'rnn.1.embedding' not in name:##不加载最后一层权重\n",
    "                 modelWeightDict[name] = v\n",
    "            \n",
    "model.load_state_dict(modelWeightDict)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CRNN(\n",
       "  (cnn): Sequential(\n",
       "    (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (relu0): ReLU(inplace)\n",
       "    (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (relu1): ReLU(inplace)\n",
       "    (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu2): ReLU(inplace)\n",
       "    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (relu3): ReLU(inplace)\n",
       "    (pooling2): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)\n",
       "    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu4): ReLU(inplace)\n",
       "    (conv5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (relu5): ReLU(inplace)\n",
       "    (pooling3): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)\n",
       "    (conv6): Conv2d(512, 512, kernel_size=(2, 2), stride=(1, 1))\n",
       "    (batchnorm6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu6): ReLU(inplace)\n",
       "  )\n",
       "  (rnn): Sequential(\n",
       "    (0): BidirectionalLSTM(\n",
       "      (rnn): LSTM(512, 256, bidirectional=True)\n",
       "      (embedding): Linear(in_features=512, out_features=256, bias=True)\n",
       "    )\n",
       "    (1): BidirectionalLSTM(\n",
       "      (rnn): LSTM(256, 256, bidirectional=True)\n",
       "      (embedding): Linear(in_features=512, out_features=63, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "##优化器\n",
    "from crnn.util import strLabelConverter\n",
    "lr = 0.1\n",
    "optimizer = optim.Adadelta(model.parameters(), lr=lr)\n",
    "converter = strLabelConverter(''.join(alphabetChinese))\n",
    "criterion = CTCLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from train.ocr.dataset import resizeNormalize\n",
    "from crnn.util import loadData\n",
    "image = torch.FloatTensor(batchSize, 3, imgH, imgH)\n",
    "text = torch.IntTensor(batchSize * 5)\n",
    "length = torch.IntTensor(batchSize)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model.cuda()\n",
    "    model = torch.nn.DataParallel(model, device_ids=[0])##转换为多GPU训练模型\n",
    "    image = image.cuda()\n",
    "    criterion = criterion.cuda()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def trainBatch(net, criterion, optimizer,cpu_images, cpu_texts):\n",
    "    #data = train_iter.next()\n",
    "    #cpu_images, cpu_texts = data\n",
    "    batch_size = cpu_images.size(0)\n",
    "    loadData(image, cpu_images)\n",
    "    t, l = converter.encode(cpu_texts)\n",
    "    \n",
    "    loadData(text, t)\n",
    "    loadData(length, l)\n",
    "    preds = net(image)\n",
    "    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))\n",
    "    cost = criterion(preds, text, preds_size, length) / batch_size\n",
    "    net.zero_grad()\n",
    "    cost.backward()\n",
    "    optimizer.step()\n",
    "    return cost\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def predict(im):\n",
    "    \"\"\"\n",
    "    预测\n",
    "    \"\"\"\n",
    "    image = im.convert('L')\n",
    "    scale = image.size[1]*1.0 / 32\n",
    "    w = image.size[0] / scale\n",
    "    w = int(w)\n",
    "    transformer = resizeNormalize((w, 32))\n",
    "    \n",
    "    image = transformer(image)\n",
    "    if torch.cuda.is_available():\n",
    "        image = image.cuda()\n",
    "    image = image.view(1, *image.size())\n",
    "    image = Variable(image)\n",
    "    preds = model(image)\n",
    "    _, preds = preds.max(2)\n",
    "    preds = preds.transpose(1, 0).contiguous().view(-1)\n",
    "    preds_size = Variable(torch.IntTensor([preds.size(0)]))\n",
    "    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)\n",
    "    return sim_pred\n",
    "   \n",
    "   \n",
    "def val(net, dataset, max_iter=100):\n",
    "\n",
    "    for p in net.parameters():\n",
    "        p.requires_grad = False\n",
    "    net.eval()\n",
    "    i = 0\n",
    "    n_correct = 0\n",
    "    N = len(dataset)\n",
    "    \n",
    "    max_iter = min(max_iter, N)\n",
    "    for i in range(max_iter):\n",
    "        im,label = dataset[np.random.randint(0,N)]\n",
    "        if im.size[0]>1024:\n",
    "            continue\n",
    "        \n",
    "        pred = predict(im)\n",
    "        if pred.strip() ==label:\n",
    "                n_correct += 1\n",
    "\n",
    "    accuracy = n_correct / float(max_iter )\n",
    "    return accuracy\n",
    "    \n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from train.ocr.generic_utils import Progbar\n",
    "##进度条参考 https://github.com/keras-team/keras/blob/master/keras/utils/generic_utils.py\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 冻结预训练模型层参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0/10\n",
      "   3/1407 [..............................] - ETA: 1:28 - loss: 20.9233 - acc: 0.0000e+00"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7f9c2a3e6a58>>\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/site-packages/torch/utils/data/dataloader.py\", line 399, in __del__\n",
      "    self._shutdown_workers()\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/site-packages/torch/utils/data/dataloader.py\", line 378, in _shutdown_workers\n",
      "    self.worker_result_queue.get()\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/multiprocessing/queues.py\", line 344, in get\n",
      "    return _ForkingPickler.loads(res)\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/site-packages/torch/multiprocessing/reductions.py\", line 151, in rebuild_storage_fd\n",
      "    fd = df.detach()\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/multiprocessing/resource_sharer.py\", line 58, in detach\n",
      "    return reduction.recv_handle(conn)\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/multiprocessing/reduction.py\", line 182, in recv_handle\n",
      "    return recvfds(s, 1)[0]\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/multiprocessing/reduction.py\", line 153, in recvfds\n",
      "    msg, ancdata, flags, addr = sock.recvmsg(1, socket.CMSG_LEN(bytes_size))\n",
      "ConnectionResetError: [Errno 104] Connection reset by peer\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1407/1407 [==============================] - 56s 40ms/step - loss: 5.2403 - acc: 0.2484\n",
      "epoch:1/10\n",
      "1407/1407 [==============================] - 57s 41ms/step - loss: 0.0821 - acc: 0.7834\n",
      "epoch:2/10\n",
      "1407/1407 [==============================] - 57s 40ms/step - loss: 0.0253 - acc: 0.9102\n",
      "epoch:3/10\n",
      "1407/1407 [==============================] - 58s 41ms/step - loss: 0.0173 - acc: 0.9287\n",
      "epoch:4/10\n",
      "1407/1407 [==============================] - 57s 40ms/step - loss: 0.0139 - acc: 0.9414\n",
      "epoch:5/10\n",
      "1407/1407 [==============================] - 57s 40ms/step - loss: 0.0120 - acc: 0.9478\n",
      "epoch:6/10\n",
      "1407/1407 [==============================] - 58s 41ms/step - loss: 0.0108 - acc: 0.9570\n",
      "epoch:7/10\n",
      "1407/1407 [==============================] - 58s 41ms/step - loss: 0.0100 - acc: 0.9600\n",
      "epoch:8/10\n",
      "1407/1407 [==============================] - 57s 40ms/step - loss: 0.0094 - acc: 0.9600\n",
      "epoch:9/10\n",
      "1407/1407 [==============================] - 57s 40ms/step - loss: 0.0089 - acc: 0.9600\n"
     ]
    }
   ],
   "source": [
    "nepochs = 10\n",
    "acc  = 0\n",
    "\n",
    "interval = len(train_loader)//2##评估模型\n",
    "\n",
    "            \n",
    "for i in range(nepochs):\n",
    "    print('epoch:{}/{}'.format(i,nepochs))\n",
    "    n = len(train_loader)\n",
    "    pbar = Progbar(target=n)\n",
    "    train_iter = iter(train_loader)\n",
    "    loss = 0\n",
    "    for j in range(n):\n",
    "        for p in model.named_parameters():\n",
    "            p[1].requires_grad = True\n",
    "            if 'rnn.1.embedding' in p[0]:\n",
    "               p[1].requires_grad = True\n",
    "            else:\n",
    "                p[1].requires_grad = False##冻结模型层\n",
    "\n",
    "        model.train()\n",
    "        cpu_images, cpu_texts = train_iter.next()\n",
    "        cost = trainBatch(model, criterion, optimizer,cpu_images, cpu_texts)\n",
    "\n",
    "        loss += cost.data.numpy()\n",
    "        \n",
    "        if (j+1)%interval==0:\n",
    "            curAcc = val(model, testdataset, max_iter=1024)\n",
    "            if curAcc>acc:\n",
    "                acc = curAcc\n",
    "                torch.save(model.state_dict(), 'train/ocr/modellstm.pth')\n",
    "            \n",
    "               \n",
    "        pbar.update(j+1,values=[('loss',loss/((j+1)*batchSize)),('acc',acc)])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 释放模型层参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:10/10\n",
      "1407/1407 [==============================] - 151s 107ms/step - loss: 0.0061 - acc: 0.9766\n",
      "epoch:11/10\n",
      "1407/1407 [==============================] - 152s 108ms/step - loss: 0.0028 - acc: 0.9814\n",
      "epoch:12/10\n",
      "1407/1407 [==============================] - 152s 108ms/step - loss: 0.0017 - acc: 0.9844\n",
      "epoch:13/10\n",
      "1407/1407 [==============================] - 152s 108ms/step - loss: 0.0011 - acc: 0.9844\n",
      "epoch:14/10\n",
      "1407/1407 [==============================] - 152s 108ms/step - loss: 7.6338e-04 - acc: 0.9844\n",
      "epoch:15/10\n",
      "1407/1407 [==============================] - 152s 108ms/step - loss: 6.4547e-04 - acc: 0.9844\n",
      "epoch:16/10\n",
      " 789/1407 [===============>..............] - ETA: 1:06 - loss: 4.9605e-04 - acc: 0.9847- ETA: 1:09 -"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Process Process-18:\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/multiprocessing/process.py\", line 249, in _bootstrap\n",
      "    self.run()\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/multiprocessing/process.py\", line 93, in run\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/site-packages/torch/utils/data/dataloader.py\", line 106, in _worker_loop\n",
      "    samples = collate_fn([dataset[i] for i in batch_indices])\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/site-packages/torch/utils/data/dataloader.py\", line 106, in <listcomp>\n",
      "    samples = collate_fn([dataset[i] for i in batch_indices])\n",
      "  File \"/home/lywen/Desktop/2019/project/chineseocr/train/ocr/dataset.py\", line 35, in __getitem__\n",
      "    im = Image.open(imP).convert('L')\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/site-packages/PIL/Image.py\", line 899, in convert\n",
      "    self.load()\n",
      "  File \"/home/lywen/anaconda3/envs/chineseocr/lib/python3.6/site-packages/PIL/ImageFile.py\", line 239, in load\n",
      "    n, err_code = decoder.decode(b)\n",
      "KeyboardInterrupt\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-12-88769ddb9ef2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     18\u001b[0m         \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m         \u001b[0mcpu_images\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcpu_texts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_iter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m         \u001b[0mcost\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainBatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcpu_images\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcpu_texts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mcost\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-9-57d40fb384ae>\u001b[0m in \u001b[0;36mtrainBatch\u001b[0;34m(net, criterion, optimizer, cpu_images, cpu_texts)\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0;31m#cpu_images, cpu_texts = data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0mbatch_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcpu_images\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m     \u001b[0mloadData\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcpu_images\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m     \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconverter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcpu_texts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Desktop/2019/project/chineseocr/crnn/util.py\u001b[0m in \u001b[0;36mloadData\u001b[0;34m(v, data)\u001b[0m\n\u001b[1;32m     87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     88\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mloadData\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 89\u001b[0;31m     \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresize_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     90\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "nepochs = 10\n",
    "#acc  = 0\n",
    "\n",
    "interval = len(train_loader)//2##评估模型\n",
    "\n",
    "            \n",
    "for i in range(10,10+nepochs):\n",
    "    print('epoch:{}/{}'.format(i,nepochs))\n",
    "    n = len(train_loader)\n",
    "    pbar = Progbar(target=n)\n",
    "    train_iter = iter(train_loader)\n",
    "    loss = 0\n",
    "    for j in range(n):\n",
    "        for p in model.named_parameters():\n",
    "            p[1].requires_grad = True\n",
    "\n",
    "\n",
    "        model.train()\n",
    "        cpu_images, cpu_texts = train_iter.next()\n",
    "        cost = trainBatch(model, criterion, optimizer,cpu_images, cpu_texts)\n",
    "\n",
    "        loss += cost.data.numpy()\n",
    "        \n",
    "        if (j+1)%interval==0:\n",
    "            curAcc = val(model, testdataset, max_iter=1024)\n",
    "            if curAcc>acc:\n",
    "                acc = curAcc\n",
    "                torch.save(model.state_dict(), 'train/ocr/modellstm.pth')\n",
    "            \n",
    "               \n",
    "        pbar.update(j+1,values=[('loss',loss/((j+1)*batchSize)),('acc',acc)])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataParallel(\n",
       "  (module): CRNN(\n",
       "    (cnn): Sequential(\n",
       "      (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (relu0): ReLU(inplace)\n",
       "      (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (relu1): ReLU(inplace)\n",
       "      (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu2): ReLU(inplace)\n",
       "      (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (relu3): ReLU(inplace)\n",
       "      (pooling2): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)\n",
       "      (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu4): ReLU(inplace)\n",
       "      (conv5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (relu5): ReLU(inplace)\n",
       "      (pooling3): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)\n",
       "      (conv6): Conv2d(512, 512, kernel_size=(2, 2), stride=(1, 1))\n",
       "      (batchnorm6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu6): ReLU(inplace)\n",
       "    )\n",
       "    (rnn): Sequential(\n",
       "      (0): BidirectionalLSTM(\n",
       "        (rnn): LSTM(512, 256, bidirectional=True)\n",
       "        (embedding): Linear(in_features=512, out_features=256, bias=True)\n",
       "      )\n",
       "      (1): BidirectionalLSTM(\n",
       "        (rnn): LSTM(256, 256, bidirectional=True)\n",
       "        (embedding): Linear(in_features=512, out_features=63, bias=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "true:72390,pred:72390\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFsAAAAeCAAAAABmyD2aAAAI20lEQVR4nF3QaYxV5RnA8Xc771nu\nvs9cBmaGZRARhqksQUWQKIUKbgQaixpt7YKitmlSTUobS5vUtvpBapPGtKaxkiAanSgFK1SqoEAF\nLCiFWZDZ7527z73n3LO9Sz8ItJnn2//D88uTB04gjZFiBLhUCEVyamqYAQIYgBACAKQEXw3HCEoJ\nwNWWihUT5ZAA0yZY1bqqIFoDAMBxLgGINTGEQEoJAJKcAQS4AuBXxpUdhXMJoAToSiM6pQAv4k/H\nRX5Jf8bzIk0AiMqMBiS+Q32CfUx86guMgUBIAgmu+AAAIISEGIlrlEdUjBTgKtPspMyHXcVqSgCQ\n5UhbyRsaxYRgQpDHoYKFgABIKa99BABfYgw451cbeRIi4MNpNLhMhWjiJEIIkRCaCkgPIYCFBJz7\nmgScMYSvsNdwDCFjAuOrHWhIgQQIOtPsqMSGyAUgBoA09OYMO2u5RPFAEHhBgYQnNYpMCCH8Hw0g\n4BJRShtXmiFdCODwaTQIVpgTUwKAAYA0d97ApO1THIKeIurMEVIS1c0xDiD6P1wwgVXYmLjaNqa1\nogDTzwZlFKMwJWtSSuKewdsM2ZZ+Ri2WzNJw/AaCpbRG+lZDCSGA12wkAZbVoYkNV5qQqX/bS6J0\n+uEqA0qNekkHAGLuvi+ml/MXYHOwOudStEtfag72HRyN7e0JrJpvB6jrx5vcTfh+7fWQ7xypH39w\noUtcqFhndyHRHXwomHZKuuaSUAGpjCq1oBkmXtRiQEIy9+TS4czmdzpOaDzsZAa1Zfyzx2devn5S\n5MnYU3PMfCt3G8H42F7Sf3BWlOVjx+j29qFW84WBHjhFq8cOvqQ5EUM6dCRFSxGAo+cH//LKTBnx\nAVPI6IZHP+xKuBVsZspciy4owMGM1+qgluHooZatgc5GkwaNyvnTl+NZVuMhZhwrPnSHue+D2Jdj\nnQ+3vXn8lafTvAqFlW3mO5vVw2OHp0RJOhmBhSCpXdtLwm6uqM5d/U5zInK6rRi1dfqDm7Sh5/ib\nF18yRQQ7jaP/yOPuW2Zmi/37z80e+2hZ4Z8oOnT/w2po9i8/PtetpRu6YCIxfPDc8ISODWv/ZsYp\nZsQ0VuQX/PGtjQOjK++7+Fh1zvC8uZ3lm7rjVsfeY0+cP7CaxBtO4HJ/28ue7hJ/Y/eOauprzljV\nWbqmQ2UytOiTF7Orvm/7kocnnxmbnE0IntpBjz7NlqhNpAWe2LP70iozucwt/03hfaCU2bf87R9P\nGqhu6yiigGozqV+o3O4zREA0bLU3E6tS5hSKtlIuQfH04sq5o/3RKIpaL56Ciyc1d8t3Urb97Nuj\nGCIX8+T4vDaUDKDRwwU7GZTt9o92LPstd2MnhLyMNA/hhjp/vqFKyOSp53O33i4nOAiKTq8e4Op3\nvZB64VXTc6zDxxfqRnbTgUceNlUzN/SFoEQyBbYXoGJa4X+da1mzTlMqSvgBxeQtRz6YMbKA8SAu\nhX6mBhgHUwdLrV0N4mqY+s7ghejMSRLOCQm1S8JLFfawnGE+O4/66ssHe8O5L7lCFCYCNRofCeqD\n7ySsO24ADLghjiC8+AKze7yQMKneTPiukssc26fPz65o6yy0RgLso/6tjH4ybFVItDTUu8U7WpE0\nszOm2SpJb936eEOYLTDHKAPI1WVpe53+sKfVUZ3IWGKi9u6nhYS1846xgBcDrhupvfvxADLy7NaW\n0IYZqKptH8e0ykN++p7rfzepLt5dGtqW7Pq5HoB6UVGR3fvazM03EYhVR0flZOnJZnnFWh8hk9QS\nzp/O2k2i32WUOx2Pcw36h16dG87BhP/ZlPHei+nsaFDBTTgrVYQLV/0EySiAz82QuWi8BJlmNIVx\n53P+iW4CCCecRPJvjMC5bc1UnkbISMtTGw8EorVt9zXpKIhblKOhvbGLTnZSuyWHBoJvPW0H/1wf\nzg3eRkKCX7IUdmdRT4zglQ1ucNHkwbqlJM2aRnzV1Bsose/d1mr9lgRUQw5uf8o9mmQ7zLtycS8C\nok2trrZuqg3PWvn2l7cld2e/5wBHmGrb7EXz8tjqeD1RmL3QZCvzg27XeCUKkib32i8mxtMGgZip\npvA+dsrLsjfCcsRM931Y+zxu7VoXGNOAi0m+php2eCsWTtvSRrYvNPr75W1doYqtuZEhEQ4f2e/R\nb7B08d5q4dCajRvqmlOZ18j3TugLISFQkQFxoc+dG/lNvhCyHAb/WkuK9aXeu+OsOD8P8dxCYiAZ\nd0JsCsPm7Cd//bl14yQxoxRSn8qTuy/RReu5oOxbzonTXu/qpeeLj/X9/X2a6hGEM2hr5rmClnp0\nMkybuDU/ZAN4954z92x7YPWcyyEhyztR3/1rBTKEq/njBz/nG4qxYqtVCZWThfFflWLGijZ3uMMD\ny4+0uM5b+zKF0fMobqfnWATXo9JRe2PFLqpyi6NK/IxK/PeV+uXknoEYWBTNMGWtdSCyb+dLG5fX\n87uota7NKGmOFVMc6V2spa3uDVpZg0xZ8NNc4czJRDE86rf0hB8pUKIxwXg5b3775lnYBxAEKlXM\nCpGw/l775i9OjNNe+MWnA3zN+Y6xC6cMJJzCmi2zbNVwApFJd/APx1OV1etDDQ1XOQwtXEw2Dp+w\nq/3kns2RhkKIDRimOamt7dCagiBoheKW6FmA9lw3vre0paWH8bO47pt96dfUQKF1Kr54y3LmhLET\nHAP8zZN+oWfTStdRpUslUBQlbSwN+H2F67J5lSgEqx7wxplGWdNTkIA4loHB278O7x35xVT0lPFN\nL7DOONS14o1MKFgOpLLb1Q5ZV4Srmmm//xC/d0nLYmAbkoc8X0BuS6VWD7SnqYUJBsTRfAKN21oo\nwhRxH/ts082aoYPQwtT+91PriQ1iD27j9Oz5wVlps0tKrnq6brop/T+Znu7o43EC8yzmM4ULokCG\ntHSVKYoLI66QEBBmGGWZ+DSMGPCFDlxfx3WUblRmOKN6diCNo6x/tp4DbW6DcH/WsOGDkGcbnj4W\njxTNDEQmiDOTqEwi7kpliqoUAy4gkP8FW6+osudGbA8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=91x30 at 0x7F9BB5650E48>"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "N  = len(testdataset)\n",
    "im,label = testdataset[np.random.randint(0,N)]\n",
    "pred = predict(im)\n",
    "print('true:{},pred:{}'.format(label,pred))\n",
    "im\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chineseocr",
   "language": "python",
   "name": "chineseocr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
